import torch
import random


def shuffle_indices(size, seed=None):
    """Return a random permutation of ``range(size)`` as a list.

    Args:
        size: number of indices to permute.
        seed: optional seed. When given, the permutation is drawn from a
            private ``random.Random(seed)``. The result is reproducible for
            that seed, and the module-level ``random`` state is not reseeded
            as a side effect. When ``None``, the global ``random`` module is
            used.

    Returns:
        A list containing each of ``0 .. size - 1`` exactly once.
    """
    # A private generator seeded identically yields the same permutation that
    # random.seed(seed); random.shuffle(...) would, without clobbering the
    # global RNG that other code may depend on.
    rng = random if seed is None else random.Random(seed)
    indices = list(range(size))
    rng.shuffle(indices)
    return indices


def shuffle_tensors2(tensor, current_indices, target_indices):
    """Reorder the entries of ``tensor`` according to index labels.

    ``current_indices[k]`` is the label of ``tensor[k]``. The result stacks,
    in the order given by ``target_indices``, the entry carrying each label.
    If a label appears more than once in ``current_indices``, the last
    occurrence wins. A label absent from ``current_indices`` raises
    ``KeyError``.
    """
    lookup = dict(zip(current_indices, tensor))
    picked = []
    for label in target_indices:
        picked.append(lookup[label])
    return torch.stack(picked)


def grid_to_list(tensor, grid_size):
    """Cut each grid image into its ``grid_size`` x ``grid_size`` tiles.

    Args:
        tensor: a tensor or sequence of grid images. Each element is sliced
            with its first two dims treated as (H, W); any trailing dims
            (e.g. channels) are carried along unchanged.
        grid_size: number of tiles along each spatial axis.

    Returns:
        Tensor of shape ``(len(tensor) * grid_size**2, H // grid_size,
        W // grid_size, ...)``. Tiles are ordered grid by grid, row-major
        within each grid.

    Raises:
        ValueError: if a grid's H or W is not divisible by ``grid_size``.
    """
    tiles = []
    for grid in tensor:
        height, width = grid.shape[0], grid.shape[1]
        # Non-divisible sizes would produce ragged or misaligned tiles;
        # fail loudly instead of returning wrong data.
        if height % grid_size or width % grid_size:
            raise ValueError(
                f"grid of size {height}x{width} cannot be split into "
                f"{grid_size}x{grid_size} equal tiles"
            )
        tile_h = height // grid_size
        tile_w = width // grid_size
        for row in range(grid_size):
            for col in range(grid_size):
                tiles.append(
                    grid[row * tile_h:(row + 1) * tile_h,
                         col * tile_w:(col + 1) * tile_w]
                )
    return torch.stack(tiles)


def list_to_grid(tensor, grid_size):
    """Reassemble consecutive tiles into grid images.

    Every ``grid_size**2`` consecutive tiles form one grid, laid out
    row-major. Within a row, tiles are joined along dim -2 (width), and the
    rows are joined along dim -3 (height). This reverses the tile ordering
    that ``grid_to_list`` produces.

    Args:
        tensor: a tensor or sequence of tiles, each shaped (h, w, C).
        grid_size: number of tiles per grid side.

    Returns:
        Tensor of shape ``(len(tensor) // grid_size**2, grid_size * h,
        grid_size * w, C)``.

    Raises:
        ValueError: if ``len(tensor)`` is not a multiple of ``grid_size**2``
            (previously the leftover tiles were silently dropped).
    """
    tiles_per_grid = grid_size * grid_size
    if len(tensor) % tiles_per_grid:
        raise ValueError(
            f"{len(tensor)} tiles cannot be grouped into "
            f"{grid_size}x{grid_size} grids"
        )
    grids = []
    for start in range(0, len(tensor), tiles_per_grid):
        rows = [
            # list(...) so torch.cat gets a sequence even when `tensor` is a tensor.
            torch.cat(
                list(tensor[start + r * grid_size:start + (r + 1) * grid_size]),
                dim=-2,
            )
            for r in range(grid_size)
        ]
        grids.append(torch.cat(rows, dim=-3))
    return torch.stack(grids)


def flatten_grid(x, grid_shape):
    """Lay the row bands of a grid image side by side along the width.

    ``x`` is split along dim 1 into ``hs`` equal bands of ``H // hs`` rows,
    and the bands are concatenated along dim 2. Only ``hs`` from
    ``grid_shape`` affects the result.

    Args:
        x: 4-D tensor, unpacked as (B, H, W, C).
        grid_shape: ``(hs, ws)`` tile counts along height and width.

    Returns:
        Tensor of shape ``(B, H // hs, W * hs, C)``.

    Raises:
        ValueError: if ``x`` is not 4-D or H is not divisible by ``hs``.
    """
    _, height, _, _ = x.size()
    rows, _ = grid_shape
    # Without this check, a non-divisible H yields extra/short bands:
    # either a cat error or silently wrong output.
    if height % rows:
        raise ValueError(
            f"height {height} is not divisible by grid rows {rows}"
        )
    band_height = height // rows
    return torch.cat(torch.split(x, band_height, dim=1), dim=2)


def unflatten_grid(x, grid_shape):
    """Stack the width segments of a flattened strip back vertically.

    ``x`` is split along dim 2 into segments of ``W // ws`` columns, and the
    segments are concatenated along dim 1. Only ``ws`` from ``grid_shape``
    affects the result. With a square ``grid_shape`` this undoes
    ``flatten_grid``.

    Args:
        x: 4-D tensor, unpacked as (B, H, W, C) (channels last).
        grid_shape: ``(hs, ws)`` tile counts along height and width.

    Returns:
        Tensor of shape ``(B, H * ws, W // ws, C)``.

    Raises:
        ValueError: if ``x`` is not 4-D or W is not divisible by ``ws``.
    """
    _, _, width, _ = x.size()
    _, cols = grid_shape
    # A non-divisible W would produce ragged segments and a wrong layout.
    if width % cols:
        raise ValueError(
            f"width {width} is not divisible by grid columns {cols}"
        )
    segment_width = width // cols
    return torch.cat(torch.split(x, segment_width, dim=2), dim=1)
